"""Utilities related to determining the reachability of code (in semantic analysis)."""

from typing import Tuple, TypeVar, Union, Optional
from typing_extensions import Final

from mypy.nodes import (
    Expression, IfStmt, Block, AssertStmt, NameExpr, UnaryExpr, MemberExpr, OpExpr, ComparisonExpr,
    StrExpr, UnicodeExpr, CallExpr, IntExpr, TupleExpr, IndexExpr, SliceExpr, Import, ImportFrom,
    ImportAll, LITERAL_YES
)
from mypy.options import Options
from mypy.traverser import TraverserVisitor
from mypy.literals import literal

# Possible inferred truth values of a condition expression.  The MYPY_*
# variants describe conditions whose value while type checking differs from
# their value at runtime (e.g. `TYPE_CHECKING`).
ALWAYS_TRUE = 1  # type: Final
MYPY_TRUE = 2  # type: Final  # True when type checking, False at runtime
ALWAYS_FALSE = 3  # type: Final
MYPY_FALSE = 4  # type: Final  # False when type checking, True at runtime
TRUTH_VALUE_UNKNOWN = 5  # type: Final

# Truth value of `not <cond>` given the truth value of `<cond>`.  The MYPY_*
# values swap with each other so the mypy-vs-runtime distinction survives
# negation; an unknown value stays unknown.
inverted_truth_mapping = {
    ALWAYS_TRUE: ALWAYS_FALSE,
    ALWAYS_FALSE: ALWAYS_TRUE,
    TRUTH_VALUE_UNKNOWN: TRUTH_VALUE_UNKNOWN,
    MYPY_TRUE: MYPY_FALSE,
    MYPY_FALSE: MYPY_TRUE,
}  # type: Final


def infer_reachability_of_if_statement(s: IfStmt, options: Options) -> None:
    """Mark the branches of an if/elif/else chain that can never execute.

    Each condition is evaluated with infer_condition_value().  A branch whose
    condition is always false (or false under mypy) is marked unreachable.
    The first branch whose condition is always true (or true under mypy)
    ends the scan.  Every later elif body and the else body are then marked
    unreachable, and an empty else body is created if there was none.  For a
    MYPY_TRUE branch, the imports inside its body are also flagged as
    mypy-only.
    """
    for position, condition in enumerate(s.expr):
        verdict = infer_condition_value(condition, options)
        if verdict in (ALWAYS_FALSE, MYPY_FALSE):
            # This if/elif body is skipped; later conditions still matter.
            mark_block_unreachable(s.body[position])
            continue
        if verdict not in (ALWAYS_TRUE, MYPY_TRUE):
            # Unknown condition: nothing can be concluded for this branch.
            continue

        if verdict == MYPY_TRUE:
            # True while type checking but false at runtime; this affects
            # import priorities.
            mark_block_mypy_only(s.body[position])
        for later_body in s.body[position + 1:]:
            mark_block_unreachable(later_body)

        # Guarantee an else body exists and is unreachable, so the type
        # checker knows every control-flow path goes through this branch.
        if not s.else_body:
            s.else_body = Block([])
        mark_block_unreachable(s.else_body)
        break


def assert_will_always_fail(s: AssertStmt, options: Options) -> bool:
    """Return True if the assert's condition is inferred to be always false.

    Both ALWAYS_FALSE and MYPY_FALSE count, i.e. the assertion is treated as
    failing whenever its condition is false while type checking.
    """
    verdict = infer_condition_value(s.expr, options)
    return verdict == ALWAYS_FALSE or verdict == MYPY_FALSE


def infer_condition_value(expr: Expression, options: Options) -> int:
    """Infer whether the given condition is always true/false.

    Return ALWAYS_TRUE if always true, ALWAYS_FALSE if always false,
    MYPY_TRUE if true under mypy and false at runtime, MYPY_FALSE if
    false under mypy and true at runtime, else TRUTH_VALUE_UNKNOWN.

    Recognized conditions:

    - ``not <cond>``: the operand is inferred recursively and the result is
      inverted, so negated ``and``/``or`` expressions and nested ``not`` are
      handled as well;
    - ``<cond> and <cond>`` / ``<cond> or <cond>``: the right operand's value
      is used only for ``true and X`` / ``false or X``; otherwise the left
      operand's value is the result;
    - comparisons involving ``sys.version_info`` (checked against
      ``options.python_version``) or ``sys.platform`` (checked against
      ``options.platform``), and ``sys.platform.startswith(...)``;
    - a bare name or attribute called ``PY2``, ``PY3``, ``MYPY`` or
      ``TYPE_CHECKING``, or one listed in ``options.always_true`` /
      ``options.always_false``.
    """
    if isinstance(expr, UnaryExpr) and expr.op == 'not':
        # Infer the operand with the full set of rules and invert the result.
        # Recursing here, rather than stripping the `not` and carrying on,
        # makes the negation apply to `and`/`or` operands too (that branch
        # returns early below). It also handles nested `not not <cond>`.
        return inverted_truth_mapping[infer_condition_value(expr.expr, options)]

    pyversion = options.python_version
    name = ''
    result = TRUTH_VALUE_UNKNOWN
    if isinstance(expr, NameExpr):
        name = expr.name
    elif isinstance(expr, MemberExpr):
        name = expr.name
    elif isinstance(expr, OpExpr) and expr.op in ('and', 'or'):
        left = infer_condition_value(expr.left, options)
        if ((left in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == 'and') or
                (left in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == 'or')):
            # Either `True and <other>` or `False or <other>`: the result will
            # always be the right-hand-side.
            return infer_condition_value(expr.right, options)
        # The result will always be the left-hand-side (e.g. ALWAYS_* or
        # TRUTH_VALUE_UNKNOWN).
        return left
    else:
        # Only sys.version_info / sys.platform checks can be decided here;
        # anything else stays TRUTH_VALUE_UNKNOWN.
        result = consider_sys_version_info(expr, pyversion)
        if result == TRUTH_VALUE_UNKNOWN:
            result = consider_sys_platform(expr, options.platform)
    if result == TRUTH_VALUE_UNKNOWN:
        # Well-known and user-configured names, matched both as `X` and as
        # the attribute part of `something.X`.
        if name == 'PY2':
            result = ALWAYS_TRUE if pyversion[0] == 2 else ALWAYS_FALSE
        elif name == 'PY3':
            result = ALWAYS_TRUE if pyversion[0] == 3 else ALWAYS_FALSE
        elif name == 'MYPY' or name == 'TYPE_CHECKING':
            result = MYPY_TRUE
        elif name in options.always_true:
            result = ALWAYS_TRUE
        elif name in options.always_false:
            result = ALWAYS_FALSE
    return result


def consider_sys_version_info(expr: Expression, pyversion: Tuple[int, ...]) -> int:
    """Evaluate a single comparison whose left operand is sys.version_info.

    Supported shapes (sys.version_info must be on the left):

    - ``sys.version_info[<int>] <op> <int>`` (index 0 or 1 only)
    - ``sys.version_info[:<int>] <op> <tuple of ints>``
    - ``sys.version_info <op> <tuple of 1 or 2 ints>``; when the tuple is
      shorter than the known version prefix only ordering operators
      (<, <=, >, >=) are decided, never == or !=.

    Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.
    """
    # Chained comparisons are not supported yet.
    if not isinstance(expr, ComparisonExpr) or len(expr.operators) > 1:
        return TRUTH_VALUE_UNKNOWN
    op = expr.operators[0]
    if op not in ('==', '!=', '<=', '>=', '<', '>'):
        return TRUTH_VALUE_UNKNOWN
    rhs = contains_int_or_tuple_of_ints(expr.operands[1])
    if rhs is None:
        return TRUTH_VALUE_UNKNOWN
    selector = contains_sys_version_info(expr.operands[0])

    if isinstance(selector, int) and isinstance(rhs, int):
        # sys.version_info[i] <op> k: only the major/minor parts are known.
        if not 0 <= selector <= 1:
            return TRUTH_VALUE_UNKNOWN
        return fixed_comparison(pyversion[selector], op, rhs)

    if isinstance(selector, tuple) and isinstance(rhs, tuple):
        start, stop = selector
        start = 0 if start is None else start
        stop = 2 if stop is None else stop
        if 0 <= start < stop <= 2:
            known = pyversion[start:stop]
            same_length = len(known) == len(rhs)
            ordered_prefix = len(known) > len(rhs) and op not in ('==', '!=')
            if same_length or ordered_prefix:
                return fixed_comparison(known, op, rhs)

    return TRUTH_VALUE_UNKNOWN


def consider_sys_platform(expr: Expression, platform: str) -> int:
    """Evaluate a check on sys.platform against the target ``platform``.

    Supported shapes:

    - ``sys.platform == '<str>'`` and ``sys.platform != '<str>'``
    - ``sys.platform.startswith('<str>')``

    Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.
    """
    if isinstance(expr, ComparisonExpr):
        # Chained comparisons and ordering operators are not supported.
        if len(expr.operators) > 1 or expr.operators[0] not in ('==', '!='):
            return TRUTH_VALUE_UNKNOWN
        literal_side = expr.operands[1]
        if (is_sys_attr(expr.operands[0], 'platform')
                and isinstance(literal_side, (StrExpr, UnicodeExpr))):
            return fixed_comparison(platform, expr.operators[0], literal_side.value)
        return TRUTH_VALUE_UNKNOWN

    if isinstance(expr, CallExpr):
        method = expr.callee
        if not isinstance(method, MemberExpr) or method.name != 'startswith':
            return TRUTH_VALUE_UNKNOWN
        if not is_sys_attr(method.expr, 'platform') or len(expr.args) != 1:
            return TRUTH_VALUE_UNKNOWN
        prefix_arg = expr.args[0]
        if not isinstance(prefix_arg, (StrExpr, UnicodeExpr)):
            return TRUTH_VALUE_UNKNOWN
        return ALWAYS_TRUE if platform.startswith(prefix_arg.value) else ALWAYS_FALSE

    return TRUTH_VALUE_UNKNOWN


Targ = TypeVar('Targ', int, str, Tuple[int, ...])


def fixed_comparison(left: Targ, op: str, right: Targ) -> int:
    """Evaluate ``left <op> right`` on two concrete values.

    Return ALWAYS_TRUE or ALWAYS_FALSE for the six supported comparison
    operators, and TRUTH_VALUE_UNKNOWN for any other operator string.
    """
    if op == '==':
        outcome = left == right
    elif op == '!=':
        outcome = left != right
    elif op == '<=':
        outcome = left <= right
    elif op == '>=':
        outcome = left >= right
    elif op == '<':
        outcome = left < right
    elif op == '>':
        outcome = left > right
    else:
        return TRUTH_VALUE_UNKNOWN
    return ALWAYS_TRUE if outcome else ALWAYS_FALSE


def contains_int_or_tuple_of_ints(expr: Expression
                                  ) -> Union[None, int, Tuple[int], Tuple[int, ...]]:
    """Extract the literal value of an int or a tuple of int literals.

    Return the int for an IntExpr, and a tuple of ints for a literal
    TupleExpr whose items are all IntExpr.  Return None for anything else,
    including tuples containing non-int items.
    """
    if isinstance(expr, IntExpr):
        return expr.value
    if isinstance(expr, TupleExpr) and literal(expr) == LITERAL_YES:
        elements = expr.items
        if all(isinstance(element, IntExpr) for element in elements):
            return tuple(element.value for element in elements)  # type: ignore
    return None


def contains_sys_version_info(expr: Expression
                              ) -> Union[None, int, Tuple[Optional[int], Optional[int]]]:
    """Describe how ``expr`` selects from sys.version_info.

    Return an int for ``sys.version_info[<int>]``.  Return a ``(start, stop)``
    pair for a slice with int-literal bounds and step 1 (omitted bounds are
    None), or ``(None, None)`` for bare ``sys.version_info``.  Return None
    for anything else.
    """
    if is_sys_attr(expr, 'version_info'):
        # The bare attribute is the same as the full slice sys.version_info[:].
        return (None, None)
    if not isinstance(expr, IndexExpr) or not is_sys_attr(expr.base, 'version_info'):
        return None

    subscript = expr.index
    if isinstance(subscript, IntExpr):
        return subscript.value
    if not isinstance(subscript, SliceExpr):
        return None

    step = subscript.stride
    if step is not None and not (isinstance(step, IntExpr) and step.value == 1):
        return None

    lower = upper = None
    lower_expr = subscript.begin_index
    if lower_expr is not None:
        if not isinstance(lower_expr, IntExpr):
            return None
        lower = lower_expr.value
    upper_expr = subscript.end_index
    if upper_expr is not None:
        if not isinstance(upper_expr, IntExpr):
            return None
        upper = upper_expr.value
    return (lower, upper)


def is_sys_attr(expr: Expression, name: str) -> bool:
    """Return True if ``expr`` is syntactically the attribute ``sys.<name>``."""
    # TODO: This currently doesn't work with code like this:
    # - import sys as _sys
    # - from sys import version_info
    # TODO: Guard against a local named sys, etc.
    # (Though later passes will still do most checking.)
    if not isinstance(expr, MemberExpr) or expr.name != name:
        return False
    receiver = expr.expr
    return isinstance(receiver, NameExpr) and receiver.name == 'sys'


def mark_block_unreachable(block: Block) -> None:
    """Flag ``block`` as unreachable, along with every import nested in it."""
    block.is_unreachable = True
    import_marker = MarkImportsUnreachableVisitor()
    block.accept(import_marker)


class MarkImportsUnreachableVisitor(TraverserVisitor):
    """Traverse a node and set ``is_unreachable`` on every import inside it.

    Covers ``import m``, ``from m import x`` and ``from m import *``.
    """

    def _flag(self, node: Union[Import, ImportFrom, ImportAll]) -> None:
        # Shared body for the three import visitors below.
        node.is_unreachable = True

    def visit_import(self, node: Import) -> None:
        self._flag(node)

    def visit_import_from(self, node: ImportFrom) -> None:
        self._flag(node)

    def visit_import_all(self, node: ImportAll) -> None:
        self._flag(node)


def mark_block_mypy_only(block: Block) -> None:
    """Flag every import nested in ``block`` as mypy-only.

    The block itself is not marked; only its imports are.
    """
    import_marker = MarkImportsMypyOnlyVisitor()
    block.accept(import_marker)


class MarkImportsMypyOnlyVisitor(TraverserVisitor):
    """Traverse a node and set ``is_mypy_only`` on every import inside it.

    The flag affects import priority.  It covers ``import m``,
    ``from m import x`` and ``from m import *``.
    """

    def _flag(self, node: Union[Import, ImportFrom, ImportAll]) -> None:
        # Shared body for the three import visitors below.
        node.is_mypy_only = True

    def visit_import(self, node: Import) -> None:
        self._flag(node)

    def visit_import_from(self, node: ImportFrom) -> None:
        self._flag(node)

    def visit_import_all(self, node: ImportAll) -> None:
        self._flag(node)
